import torch
import torch.nn as nn
import torch.nn.functional as F

class CaMIB(nn.Module):
    """Causal multimodal information-bottleneck head.

    Each modality input (``t``, ``v``, ``a``) is projected, passed through a
    Gaussian variational bottleneck sampled with the reparameterisation trick,
    and decoded to ``output_dim``. The three sampled codes are then processed
    in three steps:

    - The codes are concatenated and fused.
    - A learned two-way softmax splits the fused features into a "trivial"
      branch and a "causal" branch.
    - The causal branch is aligned (MSE) with a cross-modal self-attention
      summary, labelled "Instrumental Variable" in this module.

    Args:
        input_dim: Feature size of each modality input.
        dim: Hidden / bottleneck size.
        output_dim: Size of every decoded output.
        beta: Weight applied to the summed KL terms.
    """

    def __init__(self, input_dim=512, dim=512, output_dim=256, beta=1e-3):
        super(CaMIB, self).__init__()
        # NOTE: creation order below is kept as-is: it fixes the order of
        # parameter initialisation (RNG draws) and of state_dict entries.

        # MIB: per-modality projection + Gaussian encoder heads (mean / std).
        self.fc_t = nn.Linear(input_dim, dim)
        self.fc_mu_t = nn.Linear(dim, dim)
        self.fc_std_t = nn.Linear(dim, dim)
        #
        self.fc_v = nn.Linear(input_dim, dim)
        self.fc_mu_v = nn.Linear(dim, dim)
        self.fc_std_v = nn.Linear(dim, dim)
        #
        self.fc_a = nn.Linear(input_dim, dim)
        self.fc_mu_a = nn.Linear(dim, dim)
        self.fc_std_a = nn.Linear(dim, dim)
        # Per-modality decoders from the sampled code.
        self.decoder_t = nn.Linear(dim, output_dim)
        self.decoder_v = nn.Linear(dim, output_dim)
        self.decoder_a = nn.Linear(dim, output_dim)
        # Fusion of the three concatenated codes back to `dim`.
        self.decoder = nn.Linear(dim * 3, dim)
        self.b = beta  # KL weight

        # Causal: 2-way soft attention splitting fused features into
        # a trivial (index 0) and a causal (index 1) branch.
        self.att_mlp = nn.Linear(dim, 2)
        self.decoder_c = nn.Linear(dim, output_dim)
        self.decoder_o = nn.Linear(dim, output_dim)

        # Instrumental Variable: projections for cross-modal self-attention.
        self.W_Q = nn.Linear(dim, dim)
        self.W_K = nn.Linear(dim, dim)
        self.W_V = nn.Linear(dim, dim)

        self.criterion = nn.MSELoss()

    def encode_t(self, x):
        """Return ``(mu, std)`` for the text branch; std = softplus(h - 5) > 0."""
        return self.fc_mu_t(x), F.softplus(self.fc_std_t(x) - 5, beta=1)

    def encode_v(self, x):
        """Return ``(mu, std)`` for the ``v`` branch; std = softplus(h - 5) > 0."""
        return self.fc_mu_v(x), F.softplus(self.fc_std_v(x) - 5, beta=1)

    def encode_a(self, x):
        """Return ``(mu, std)`` for the ``a`` branch; std = softplus(h - 5) > 0."""
        return self.fc_mu_a(x), F.softplus(self.fc_std_a(x) - 5, beta=1)

    def reparameterise(self, mu, std):
        """Sample ``mu + std * eps`` with ``eps ~ N(0, 1)``.

        NOTE(review): this samples regardless of ``self.training``, so eval
        outputs are stochastic too — confirm that is intended.
        """
        eps = torch.randn_like(std)
        return mu + std * eps

    def _gaussian_kl(self, mu, std):
        """Mean (over all elements) of KL( N(mu, std^2) || N(0, 1) )."""
        # softplus can underflow to exactly 0; clamping to the dtype's
        # smallest normal value keeps log() finite without changing results
        # for any normally-representable std.
        log_std = std.clamp(min=torch.finfo(std.dtype).tiny).log()
        return 0.5 * torch.mean(mu.pow(2) + std.pow(2) - 2 * log_std - 1)

    def _bottleneck(self, x, fc, encode, decoder):
        """Project -> encode -> sample -> decode one modality.

        Returns ``(z, decoded, kl)``: the sampled code, its decoder output and
        its (unweighted) KL term.
        """
        mu, std = encode(fc(x))
        z = self.reparameterise(mu, std)
        return z, decoder(z), self._gaussian_kl(mu, std)

    def _instrumental_variable(self, z_t, z_v, z_a):
        """Joint scaled dot-product self-attention over all modality tokens.

        The three codes are stacked to (B, M=3, L, D), flattened to M*L
        tokens, attended jointly, then summed over the modality axis,
        giving a (B, L, D) tensor.
        """
        feats = torch.stack((z_t, z_v, z_a), dim=1)
        B, M, L, D = feats.shape
        q = self.W_Q(feats).reshape(B, M * L, D)
        k = self.W_K(feats).reshape(B, M * L, D)
        v = self.W_V(feats).reshape(B, M * L, D)
        scores = torch.bmm(q, k.transpose(1, 2)) / (D ** 0.5)
        weights = F.softmax(scores, dim=-1)
        attended = torch.bmm(weights, v)
        return attended.view(B, M, L, D).sum(dim=1)

    def forward(self, t, v, a):
        """Run the full bottleneck + causal split.

        Args:
            t, v, a: Modality features. These are assumed to be
                (batch, seq, input_dim): fusion concatenates on dim 2, and the
                IV step unpacks a 4-D stack, so 3-D inputs are required.

        Returns:
            A tuple ``(output_t, output_v, output_a, output_c, output_o,
            align_loss, KL)``:

            - ``output_t``, ``output_v``, ``output_a``: per-modality decodings.
            - ``output_c``: the trivial-branch decoding.
            - ``output_o``: the causal-branch decoding.
            - ``align_loss``: MSE between the causal branch (before decoding)
              and the IV summary.
            - ``KL``: ``beta`` times the sum of the three KL terms.
        """
        # Sample order t, v, a is preserved (it determines RNG consumption).
        z_t, output_t, KL_t = self._bottleneck(t, self.fc_t, self.encode_t, self.decoder_t)
        z_v, output_v, KL_v = self._bottleneck(v, self.fc_v, self.encode_v, self.decoder_v)
        z_a, output_a, KL_a = self._bottleneck(a, self.fc_a, self.encode_a, self.decoder_a)

        fused = self.decoder(torch.cat((z_t, z_v, z_a), dim=2))
        iv = self._instrumental_variable(z_t, z_v, z_a)

        att = F.softmax(self.att_mlp(fused), dim=-1)
        output_c = att[..., 0:1] * fused  # trivial
        output_o = att[..., 1:2] * fused  # causal

        align_loss = self.criterion(output_o, iv)

        output_c = self.decoder_c(output_c)
        output_o = self.decoder_o(output_o)

        KL = self.b * (KL_t + KL_v + KL_a)
        return output_t, output_v, output_a, output_c, output_o, align_loss, KL

